# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import torch

from loguru import logger

import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc
from models.common.utility_functions import comp_pcc, is_blackhole, run_for_blackhole
from tests.ttnn.unit_tests.test_bh_20_cores_sharding import skip_if_not_blackhole_20_cores


# Helper function to get welford parameters based on device type
def get_welford_params():
    """Select the ``use_welford`` values and their pytest ids for this device.

    Returns:
        A ``(flavors, ids)`` pair of equal-length tuples. When
        ``is_blackhole()`` is true only the legacy mode is listed; otherwise
        both the welford and the legacy modes are listed.
    """
    legacy_only = is_blackhole()
    flavors = (False,) if legacy_only else (True, False)
    ids = ("legacy",) if legacy_only else ("welford", "legacy")
    return flavors, ids


# Evaluated once at import time so that both tuples can be handed directly to
# ``pytest.mark.parametrize("use_welford", ...)`` by the tests in this module.
welford_flavors, welford_ids = get_welford_params()


# For debugging purposes
def manual_group_norm(input_tensor, num_groups, eps=1e-2):
    """Reference group normalization (no affine weight/bias) in plain torch.

    The channels of ``input_tensor`` are split into ``num_groups`` consecutive
    groups and every group is normalized with its own mean and variance, taken
    over the group's channels and both spatial dimensions.

    Args:
        input_tensor: 4-D tensor unpacked as ``(N, C, H, W)``.
        num_groups: Number of groups; must divide ``C``.
        eps: Value added to the variance before the square root. Pass the same
            ``eps`` to ``torch.nn.functional.group_norm`` when comparing
            results, since this default is not torch's default.

    Returns:
        A tensor of shape ``(N, C, H, W)`` holding the normalized values.

    Raises:
        AssertionError: If ``C`` is not divisible by ``num_groups``.
    """
    N, C, H, W = input_tensor.shape
    assert C % num_groups == 0, "Number of channels must be divisible by number of groups"

    # Expose the group axis: (N, groups, channels_per_group, H, W).
    group_channels = C // num_groups
    grouped = input_tensor.view(N, num_groups, group_channels, H, W)

    # Per-group statistics. Group norm uses the *biased* variance (divide by
    # the element count); ``Tensor.var`` defaults to the unbiased estimator,
    # so it has to be disabled explicitly to match torch's group_norm.
    reduce_dims = (2, 3, 4)
    mean = grouped.mean(dim=reduce_dims, keepdim=True)
    var = grouped.var(dim=reduce_dims, unbiased=False, keepdim=True)

    normalized = (grouped - mean) / torch.sqrt(var + eps)

    # Fold the group axis back into the channel axis.
    return normalized.view(N, C, H, W)


@pytest.mark.parametrize("N", [1])
@pytest.mark.parametrize("C", [320])
@pytest.mark.parametrize("H", [32])
@pytest.mark.parametrize("W", [32])
@pytest.mark.parametrize("num_groups", [32])
@pytest.mark.parametrize("use_welford", welford_flavors, ids=welford_ids)
def test_group_norm_with_height_sharded(device, N, C, H, W, num_groups, use_welford):
    """Compare height-sharded ``ttnn.group_norm`` on a 1x8 core grid with torch.

    Both the torch reference output and the device input are rearranged from
    ``(N, C, H, W)`` to ``(N, 1, W * H, C)`` before the comparison.
    """
    torch.manual_seed(0)

    grid_size = ttnn.CoreGrid(y=1, x=8)
    flat_shape = (N, 1, W * H, C)

    # Upload settings shared by the activation, gamma and beta tensors.
    dram_rm_bf16 = dict(
        dtype=ttnn.DataType.BFLOAT16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )

    # Torch reference (random draws kept in input, weight, bias order).
    ref_input = torch.rand((N, C, H, W), dtype=torch.bfloat16)
    ref_weight = torch.rand((C,), dtype=torch.bfloat16)
    ref_bias = torch.rand((C,), dtype=torch.bfloat16)
    ref_output = torch.nn.functional.group_norm(ref_input, num_groups, weight=ref_weight, bias=ref_bias)
    ref_output = ref_output.permute(0, 2, 3, 1).view(flat_shape)

    # Device activation; the name is rebound on purpose so that the previous
    # tensor object is not kept referenced once it has been converted.
    tt_x = ttnn.from_torch(ref_input.permute(0, 2, 3, 1).view(flat_shape), **dram_rm_bf16)

    # Input mask
    tt_mask = ttnn.from_torch(
        ttnn.create_group_norm_input_mask(C, num_groups, grid_size.y),
        dtype=ttnn.DataType.BFLOAT8_B,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )

    # Gamma / beta
    gamma_host = ttnn.create_group_norm_weight_bias_rm(ref_weight, C, grid_size.y)
    beta_host = ttnn.create_group_norm_weight_bias_rm(ref_bias, C, grid_size.y)
    tt_gamma = ttnn.from_torch(gamma_host, **dram_rm_bf16)
    tt_beta = ttnn.from_torch(beta_host, **dram_rm_bf16)

    # Shard config: one core range spanning the whole grid, height sharded in L1.
    last_core = ttnn.CoreCoord(grid_size.x - 1, grid_size.y - 1)
    core_ranges = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), last_core)})
    shard_spec = ttnn.ShardSpec(
        core_ranges,
        (N * H * W // grid_size.x, C // grid_size.y),
        ttnn.ShardOrientation.COL_MAJOR,
    )
    sharded_mem_config = ttnn.MemoryConfig(
        ttnn.types.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.types.BufferType.L1, shard_spec
    )
    tt_x = ttnn.to_memory_config(tt_x, sharded_mem_config)

    tt_y = ttnn.group_norm(
        tt_x,
        num_groups=num_groups,
        input_mask=tt_mask,
        weight=tt_gamma,
        bias=tt_beta,
        memory_config=sharded_mem_config,
        core_grid=grid_size,
        use_welford=use_welford,
    )

    # Bring the result back to the host as a torch tensor.
    tt_y = ttnn.to_memory_config(tt_y, ttnn.DRAM_MEMORY_CONFIG)
    tt_y = ttnn.from_device(tt_y)
    tt_y = ttnn.to_torch(tt_y)

    pcc_threshold = 0.9997 if use_welford else 0.9998
    assert_with_pcc(ref_output, tt_y, pcc_threshold)


@pytest.mark.parametrize("device_params", [{"l1_small_size": 0}], indirect=True)
@pytest.mark.parametrize(
    "N, C, H, W, num_groups",
    [
        (1, 1280, 16, 16, 32),
        (1, 320, 1, 8192, 32),
        (1, 960, 1, 1024, 32),
        # not fit in L1 for GS
        # (1, 960, 1, 4096, 32),
    ],
)
@pytest.mark.parametrize("use_welford", welford_flavors, ids=welford_ids)
def test_group_norm_with_block_sharded_v2_8x4_grid(device, N, C, H, W, num_groups, use_welford):
    """Compare block-sharded ``ttnn.group_norm`` on an 8x4 (x by y) grid with torch.

    Weight is all ones and bias all zeros; the activation starts in
    interleaved L1 and is converted to a block-sharded L1 layout.
    """
    torch.manual_seed(0)

    grid_size = ttnn.CoreGrid(y=4, x=8)
    flat_shape = (N, 1, W * H, C)

    # Upload settings shared by the gamma and beta tensors.
    dram_rm_bf16 = dict(
        dtype=ttnn.DataType.BFLOAT16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )

    # Torch reference
    ref_input = torch.rand((N, C, H, W), dtype=torch.bfloat16)
    ref_weight = torch.ones((C,), dtype=torch.bfloat16)
    ref_bias = torch.zeros((C,), dtype=torch.bfloat16)
    ref_output = torch.nn.functional.group_norm(ref_input, num_groups, weight=ref_weight, bias=ref_bias)
    ref_output = ref_output.permute(0, 2, 3, 1).view(flat_shape)

    # Device activation in interleaved L1. The name is rebound below so the
    # interleaved tensor is not kept referenced after the sharded copy exists.
    tt_x = ttnn.from_torch(
        ref_input.permute(0, 2, 3, 1).view(flat_shape),
        dtype=ttnn.DataType.BFLOAT16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device,
        memory_config=ttnn.L1_MEMORY_CONFIG,
    )

    # Input mask
    tt_mask = ttnn.from_torch(
        ttnn.create_group_norm_input_mask(C, num_groups, grid_size.y),
        dtype=ttnn.DataType.BFLOAT8_B,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )

    # Gamma / beta
    gamma_host = ttnn.create_group_norm_weight_bias_rm(ref_weight, C, grid_size.y)
    beta_host = ttnn.create_group_norm_weight_bias_rm(ref_bias, C, grid_size.y)
    tt_gamma = ttnn.from_torch(gamma_host, **dram_rm_bf16)
    tt_beta = ttnn.from_torch(beta_host, **dram_rm_bf16)

    # Shard config: one core range spanning the whole grid, block sharded in L1.
    last_core = ttnn.CoreCoord(grid_size.x - 1, grid_size.y - 1)
    core_ranges = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), last_core)})
    shard_spec = ttnn.ShardSpec(
        core_ranges,
        (N * H * W // grid_size.x, C // grid_size.y),
        ttnn.ShardOrientation.COL_MAJOR,
    )
    sharded_mem_config = ttnn.MemoryConfig(
        ttnn.types.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec
    )
    tt_x = ttnn.to_memory_config(tt_x, sharded_mem_config)

    # Group norm
    tt_y = ttnn.group_norm(
        tt_x,
        num_groups=num_groups,
        input_mask=tt_mask,
        weight=tt_gamma,
        bias=tt_beta,
        memory_config=sharded_mem_config,
        core_grid=grid_size,
        use_welford=use_welford,
    )

    # Bring the result back to the host as a torch tensor.
    tt_y = ttnn.to_memory_config(tt_y, ttnn.L1_MEMORY_CONFIG)
    tt_y = ttnn.from_device(tt_y)
    tt_y = ttnn.to_torch(tt_y)

    assert_with_pcc(ref_output, tt_y, 0.9997)


@pytest.mark.parametrize("device_params", [{"l1_small_size": 0}], indirect=True)
@pytest.mark.parametrize(
    "N, C, H, W, num_groups",
    [
        (2, 320, 64, 64, 32),
        (1, 640, 1, 2048, 32),
        (1, 640, 1, 4096, 32),
        (1, 960, 1, 2048, 32),
        (1, 960, 1, 4096, 32),
        (1, 1280, 1, 512, 32),
        (1, 1280, 1, 2048, 32),
        (1, 1920, 1, 512, 32),
        (1, 1920, 1, 2048, 32),
        (1, 2560, 1, 512, 32),
        # not fit in L1 for GS
        # (2, 960, 64, 64, 32),
        # (1, 640, 1, 8192, 32),
    ],
)
@pytest.mark.parametrize("use_welford", welford_flavors, ids=welford_ids)
def test_group_norm_with_block_sharded_v2_8x8_grid(device, N, C, H, W, num_groups, use_welford):
    """Compare block-sharded ``ttnn.group_norm`` on an 8x8 core grid with torch.

    Weight is all ones and bias all zeros. The test is skipped when
    ``device.core_grid.y == 7``. Sharding and un-sharding go through
    ``interleaved_to_sharded`` / ``sharded_to_interleaved`` with the
    L1-aligned flags enabled.
    """
    torch.manual_seed(0)
    if device.core_grid.y == 7:
        pytest.skip()

    grid_size = ttnn.CoreGrid(y=8, x=8)
    flat_shape = (N, 1, W * H, C)

    # Upload settings shared by the activation, gamma and beta tensors.
    dram_rm_bf16 = dict(
        dtype=ttnn.DataType.BFLOAT16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )

    # Torch reference
    ref_input = torch.rand((N, C, H, W), dtype=torch.bfloat16)
    ref_weight = torch.ones((C,), dtype=torch.bfloat16)
    ref_bias = torch.zeros((C,), dtype=torch.bfloat16)
    ref_output = torch.nn.functional.group_norm(ref_input, num_groups, weight=ref_weight, bias=ref_bias)
    ref_output = ref_output.permute(0, 2, 3, 1).view(flat_shape)

    # Device activation; the name is rebound below so the interleaved tensor
    # is not kept referenced after the sharded copy exists.
    tt_x = ttnn.from_torch(ref_input.permute(0, 2, 3, 1).view(flat_shape), **dram_rm_bf16)

    # Input mask
    tt_mask = ttnn.from_torch(
        ttnn.create_group_norm_input_mask(C, num_groups, grid_size.y),
        dtype=ttnn.DataType.BFLOAT8_B,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )

    # Gamma / beta
    gamma_host = ttnn.create_group_norm_weight_bias_rm(ref_weight, C, grid_size.y)
    beta_host = ttnn.create_group_norm_weight_bias_rm(ref_bias, C, grid_size.y)
    tt_gamma = ttnn.from_torch(gamma_host, **dram_rm_bf16)
    tt_beta = ttnn.from_torch(beta_host, **dram_rm_bf16)

    # Shard config: one core range spanning the whole grid, block sharded in L1.
    last_core = ttnn.CoreCoord(grid_size.x - 1, grid_size.y - 1)
    core_ranges = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), last_core)})
    shard_spec = ttnn.ShardSpec(
        core_ranges,
        (N * H * W // grid_size.x, C // grid_size.y),
        ttnn.ShardOrientation.COL_MAJOR,
    )
    sharded_mem_config = ttnn.MemoryConfig(
        ttnn.types.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec
    )
    tt_x = ttnn.interleaved_to_sharded(tt_x, sharded_mem_config, keep_l1_aligned=True)

    # Group norm
    tt_y = ttnn.group_norm(
        tt_x,
        num_groups=num_groups,
        input_mask=tt_mask,
        weight=tt_gamma,
        bias=tt_beta,
        memory_config=sharded_mem_config,
        core_grid=grid_size,
        use_welford=use_welford,
    )

    # Bring the result back to the host as a torch tensor.
    tt_y = ttnn.sharded_to_interleaved(tt_y, ttnn.L1_MEMORY_CONFIG, is_l1_aligned=True)
    tt_y = ttnn.from_device(tt_y)
    tt_y = ttnn.to_torch(tt_y)

    assert_with_pcc(ref_output, tt_y, 0.9997)


@pytest.mark.parametrize("device_params", [{"l1_small_size": 0}], indirect=True)
@pytest.mark.parametrize(
    "N, C, H, W, num_groups",
    [
        (1, 1280, 1, 512, 32),
        (1, 1280, 1, 2048, 32),
        (1, 2560, 1, 512, 32),
    ],
)
@pytest.mark.parametrize("use_welford", welford_flavors, ids=welford_ids)
def test_group_norm_with_block_sharded_v2_8x8_grid_tile_layout(device, N, C, H, W, num_groups, use_welford):
    """Compare block-sharded ``ttnn.group_norm`` on a tile-layout input with torch.

    Uses an 8x8 core grid, an all-ones weight, a random bias and
    ``inplace=False``. The test is skipped when ``device.core_grid.y == 7``.
    """
    torch.manual_seed(0)
    if device.core_grid.y == 7:
        pytest.skip()

    grid_size = ttnn.CoreGrid(y=8, x=8)
    flat_shape = (N, 1, W * H, C)

    # Upload settings shared by the gamma and beta tensors.
    dram_rm_bf16 = dict(
        dtype=ttnn.DataType.BFLOAT16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )

    # Torch reference (random draws kept in input, bias order).
    ref_input = torch.rand((N, C, H, W), dtype=torch.bfloat16)
    ref_weight = torch.ones((C,), dtype=torch.bfloat16)
    ref_bias = torch.rand((C,), dtype=torch.bfloat16)
    ref_output = torch.nn.functional.group_norm(ref_input, num_groups, weight=ref_weight, bias=ref_bias)
    ref_output = ref_output.permute(0, 2, 3, 1).view(flat_shape)

    # Device activation in tile layout; the name is rebound below so the
    # interleaved tensor is not kept referenced after the sharded copy exists.
    tt_x = ttnn.from_torch(
        ref_input.permute(0, 2, 3, 1).view(flat_shape),
        dtype=ttnn.DataType.BFLOAT16,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )

    # Input mask
    tt_mask = ttnn.from_torch(
        ttnn.create_group_norm_input_mask(C, num_groups, grid_size.y),
        dtype=ttnn.DataType.BFLOAT8_B,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )

    # Gamma / beta
    gamma_host = ttnn.create_group_norm_weight_bias_rm(ref_weight, C, grid_size.y)
    beta_host = ttnn.create_group_norm_weight_bias_rm(ref_bias, C, grid_size.y)
    tt_gamma = ttnn.from_torch(gamma_host, **dram_rm_bf16)
    tt_beta = ttnn.from_torch(beta_host, **dram_rm_bf16)

    # Shard config: one core range spanning the whole grid, block sharded in L1.
    last_core = ttnn.CoreCoord(grid_size.x - 1, grid_size.y - 1)
    core_ranges = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), last_core)})
    shard_spec = ttnn.ShardSpec(
        core_ranges,
        (N * H * W // grid_size.x, C // grid_size.y),
        ttnn.ShardOrientation.COL_MAJOR,
    )
    sharded_mem_config = ttnn.MemoryConfig(
        ttnn.types.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec
    )
    tt_x = ttnn.to_memory_config(tt_x, sharded_mem_config)

    # Group norm
    tt_y = ttnn.group_norm(
        tt_x,
        num_groups=num_groups,
        input_mask=tt_mask,
        weight=tt_gamma,
        bias=tt_beta,
        memory_config=sharded_mem_config,
        core_grid=grid_size,
        inplace=False,
        use_welford=use_welford,
    )

    # Bring the result back to the host as a torch tensor.
    tt_y = ttnn.to_memory_config(tt_y, ttnn.L1_MEMORY_CONFIG)
    tt_y = ttnn.from_device(tt_y)
    tt_y = ttnn.to_torch(tt_y)

    assert_with_pcc(ref_output, tt_y, 0.9997)


def generate_sdxl_test_inputs():
    """Return the list of 4-tuple input shapes used by the SDXL group norm test.

    The shapes are listed in three sections (UNet, VAE, Refiner UNet) and are
    returned in exactly that order.
    """
    # 1024x1024 resolution

    # UNet inputs
    unet_shapes = [
        (1, 1280, 64, 64),
        (1, 1280, 32, 32),
        (1, 1920, 64, 64),
        (1, 1920, 32, 32),
        (1, 2560, 32, 32),
        (1, 320, 128, 128),
        (1, 320, 64, 64),
        (1, 640, 64, 64),
        (1, 640, 32, 32),
        (1, 960, 64, 64),
    ]

    # VAE inputs
    vae_shapes = [
        (1, 512, 128, 128),
    ]

    # Refiner UNet inputs
    refiner_unet_shapes = [
        (1, 1152, 64, 64),
        (1, 1536, 16, 16),
        (1, 1536, 32, 32),
        (1, 1536, 64, 64),
        (1, 2304, 32, 32),
        (1, 2304, 64, 64),
        (1, 3072, 16, 16),
        (1, 3072, 32, 32),
        (1, 384, 128, 128),
        (1, 384, 64, 64),
        (1, 768, 32, 32),
        (1, 768, 64, 64),
    ]

    return unet_shapes + vae_shapes + refiner_unet_shapes


@pytest.mark.parametrize("device_params", [{"l1_small_size": 0}], indirect=True)
@pytest.mark.parametrize("input_shape", generate_sdxl_test_inputs())
@pytest.mark.parametrize("use_welford", welford_flavors, ids=welford_ids)
def test_sdxl_base_group_norm(device, input_shape, use_welford):
    """Compare ``ttnn.group_norm`` (no weight/bias) with torch on SDXL shapes.

    Runs block sharded on an 8x8 core grid. The input is uploaded in tile
    layout when ``C == 512`` and in row-major layout otherwise, and the op
    runs in place only for a non-tile input. The test is skipped when
    ``device.core_grid.y == 7``.
    """
    num_groups = 32  #  always 32 for SDXL Base 1024x1024
    N, C, H, W = input_shape
    torch.manual_seed(0)
    if device.core_grid.y == 7:
        pytest.skip()

    grid_size = ttnn.CoreGrid(y=8, x=8)
    flat_shape = (N, 1, W * H, C)

    # Torch reference
    ref_input = torch.rand(input_shape, dtype=torch.bfloat16)
    ref_output = torch.nn.functional.group_norm(ref_input, num_groups)
    ref_output = ref_output.permute(0, 2, 3, 1).view(flat_shape)

    # Device activation; the name is rebound below so the interleaved tensor
    # is not kept referenced after the sharded copy exists.
    if C == 512:
        input_layout = ttnn.TILE_LAYOUT
    else:
        input_layout = ttnn.ROW_MAJOR_LAYOUT
    tt_x = ttnn.from_torch(
        ref_input.permute(0, 2, 3, 1).view(flat_shape),
        dtype=ttnn.DataType.BFLOAT16,
        layout=input_layout,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
        device=device,
    )

    # Input mask
    tt_mask = ttnn.from_torch(
        ttnn.create_group_norm_input_mask(C, num_groups, grid_size.y),
        dtype=ttnn.DataType.BFLOAT8_B,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )

    # Shard config: one core range spanning the whole grid, block sharded in L1.
    last_core = ttnn.CoreCoord(grid_size.x - 1, grid_size.y - 1)
    core_ranges = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), last_core)})
    shard_spec = ttnn.ShardSpec(
        core_ranges,
        (N * H * W // grid_size.x, C // grid_size.y),
        ttnn.ShardOrientation.ROW_MAJOR,
    )
    sharded_mem_config = ttnn.MemoryConfig(
        ttnn.types.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec
    )
    tt_x = ttnn.to_memory_config(tt_x, memory_config=sharded_mem_config)

    # Group norm; in-place execution is requested only for non-tile inputs.
    run_inplace = tt_x.layout != ttnn.TILE_LAYOUT
    tt_y = ttnn.group_norm(
        tt_x,
        num_groups=num_groups,
        input_mask=tt_mask,
        memory_config=sharded_mem_config,
        core_grid=grid_size,
        inplace=run_inplace,
        use_welford=use_welford,
    )

    tt_y = ttnn.from_device(tt_y)
    tt_y = ttnn.to_torch(tt_y)

    assert_with_pcc(ref_output, tt_y, 0.9997)


def generate_sdxl_test_inputs_neg_mask():
    """Return the list of 4-tuple input shapes used by the negative-mask test."""
    return [
        (1, 640, 128, 128),
        (1, 960, 128, 128),
        (1, 768, 128, 128),
    ]


@pytest.mark.parametrize("device_params", [{"l1_small_size": 47000}], indirect=True)
@pytest.mark.parametrize("input_shape", generate_sdxl_test_inputs_neg_mask())
def test_sdxl_base_group_norm_negative_mask(device, input_shape):
    """Compare ``ttnn.group_norm`` with an explicit negative mask against torch.

    Runs block sharded (ROW_MAJOR orientation) on an 8x8 core grid with random
    weight and bias. Unlike the other tests in this file, the masks and the
    gamma/beta tensors are built from ``grid_size.x`` and the input is created
    on the host and then moved to the device directly in its sharded layout.
    The test is skipped when ``device.core_grid.y == 7``.
    """
    num_groups = 32  #  always 32 for SDXL Base 1024x1024
    N, C, H, W = input_shape
    torch.manual_seed(0)
    if device.core_grid.y == 7:
        pytest.skip()

    core_x = 8
    core_y = 8
    grid_size = ttnn.CoreGrid(y=core_y, x=core_x)

    # Generate torch tensor
    torch_input_tensor = torch.rand(input_shape, dtype=torch.bfloat16)
    torch_weight = torch.rand((C,), dtype=torch.bfloat16)
    torch_bias = torch.rand((C,), dtype=torch.bfloat16)

    # Execute torch group_norm
    torch_output_tensor = torch.nn.functional.group_norm(
        torch_input_tensor, num_groups, weight=torch_weight, bias=torch_bias
    )
    torch_output_tensor = torch_output_tensor.permute(0, 2, 3, 1).view(N, 1, W * H, C)

    # Generate ttnn tensor (host side only; no device is passed here)
    tt_input_tensor = torch_input_tensor.permute(0, 2, 3, 1).view(N, 1, W * H, C)
    tt_input_tensor = ttnn.from_torch(
        tt_input_tensor,
        dtype=ttnn.DataType.BFLOAT16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
    )

    # Generate input mask
    input_mask_tensor = ttnn.create_group_norm_input_mask(C, num_groups, grid_size.x)
    input_mask_tensor = ttnn.from_torch(
        input_mask_tensor,
        dtype=ttnn.DataType.BFLOAT8_B,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )

    # Generate input negative mask
    input_negative_mask_tensor = ttnn.create_group_norm_input_negative_mask(C, num_groups, grid_size.x)
    input_negative_mask_tensor = ttnn.from_torch(
        input_negative_mask_tensor,
        dtype=ttnn.DataType.BFLOAT8_B,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )

    # Generate gamma/beta tensors
    gamma = ttnn.create_group_norm_weight_bias_rm(torch_weight, C, grid_size.x)
    beta = ttnn.create_group_norm_weight_bias_rm(torch_bias, C, grid_size.x)

    gamma_t = ttnn.from_torch(
        gamma,
        dtype=ttnn.DataType.BFLOAT16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )
    beta_t = ttnn.from_torch(
        beta,
        dtype=ttnn.DataType.BFLOAT16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )

    # Generate shard config
    grid_coord = ttnn.CoreCoord(grid_size.x - 1, grid_size.y - 1)
    shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)})
    shard_shape = N * H * W // grid_size.y, C // grid_size.x
    shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.ROW_MAJOR)
    sharded_mem_config = ttnn.MemoryConfig(
        ttnn.types.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec
    )
    tt_input_tensor = ttnn.to_device(tt_input_tensor, device, memory_config=sharded_mem_config)

    # Execute ttnn group_norm
    tt_output_tensor = ttnn.group_norm(
        tt_input_tensor,
        num_groups=num_groups,
        input_mask=input_mask_tensor,
        negative_mask=input_negative_mask_tensor,
        memory_config=sharded_mem_config,
        core_grid=grid_size,
        weight=gamma_t,
        bias=beta_t,
    )

    tt_output_tensor = ttnn.from_device(tt_output_tensor)
    tt_output_tensor = ttnn.to_torch(tt_output_tensor)

    assert_with_pcc(torch_output_tensor, tt_output_tensor, 0.9997)


@pytest.mark.parametrize("device_params", [{"l1_small_size": 0}], indirect=True)
@pytest.mark.parametrize("N", [1])
@pytest.mark.parametrize("C", [1920])
@pytest.mark.parametrize("H", [64])
@pytest.mark.parametrize("W", [64])
@pytest.mark.parametrize("num_groups", [32])
def test_group_norm_compute_config(device, N, C, H, W, num_groups):
    """
    Check that running group norm with a high-accuracy compute kernel config
    gives a strictly higher PCC against torch than a low-accuracy config.

    The test is skipped when ``device.core_grid.y == 7``.
    """

    if device.core_grid.y == 7:
        pytest.skip()

    torch.manual_seed(0)
    grid_size = ttnn.CoreGrid(y=8, x=8)
    flat_shape = (N, 1, W * H, C)

    # Torch reference, computed from a float32 input
    ref_input = torch.rand((N, C, H, W), dtype=torch.float32)
    ref_output = torch.nn.functional.group_norm(ref_input, num_groups)
    ref_output = ref_output.permute(0, 2, 3, 1).view(flat_shape)

    # Input mask
    tt_mask = ttnn.from_torch(
        ttnn.create_group_norm_input_mask(C, num_groups, grid_size.y),
        dtype=ttnn.DataType.BFLOAT8_B,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )

    # Shard config: one core range spanning the whole grid, block sharded in L1.
    last_core = ttnn.CoreCoord(grid_size.x - 1, grid_size.y - 1)
    core_ranges = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), last_core)})
    shard_spec = ttnn.ShardSpec(
        core_ranges,
        (N * H * W // grid_size.x, C // grid_size.y),
        ttnn.ShardOrientation.ROW_MAJOR,
    )
    sharded_mem_config = ttnn.MemoryConfig(
        ttnn.types.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec
    )

    def run_with(compute_config):
        """Run group_norm on a fresh sharded copy of the input; return a torch tensor."""
        tt_in = ttnn.from_torch(
            ref_input.permute(0, 2, 3, 1).view(flat_shape),
            dtype=ttnn.DataType.BFLOAT16,
            layout=ttnn.ROW_MAJOR_LAYOUT,
            device=device,
            memory_config=sharded_mem_config,
        )
        tt_out = ttnn.group_norm(
            tt_in,
            num_groups=num_groups,
            input_mask=tt_mask,
            memory_config=sharded_mem_config,
            core_grid=grid_size,
            compute_kernel_config=compute_config,
        )
        host_out = ttnn.to_torch(ttnn.from_device(tt_out))

        # Release both device tensors explicitly before the next run.
        ttnn.deallocate(tt_in)
        ttnn.deallocate(tt_out)

        return host_out

    # Low-accuracy run
    low_accuracy = ttnn.WormholeComputeKernelConfig(
        math_fidelity=ttnn.MathFidelity.LoFi,
        math_approx_mode=True,
        fp32_dest_acc_en=False,
        packer_l1_acc=False,
    )
    _, pcc_low = comp_pcc(ref_output, run_with(low_accuracy))

    # High-accuracy run
    high_accuracy = ttnn.WormholeComputeKernelConfig(
        math_fidelity=ttnn.MathFidelity.HiFi4,
        math_approx_mode=False,
        fp32_dest_acc_en=True,
        packer_l1_acc=False,
    )
    _, pcc_high = comp_pcc(ref_output, run_with(high_accuracy))

    # The higher-accuracy config must land closer to torch.
    assert pcc_high > pcc_low, "High-accuracy config should have higher PCC than low-accuracy config"


@pytest.mark.parametrize(
    "N, C, H, W, num_groups, shard, eps",
    [
        (1, 256, 12, 40, 16, "BS", 1e-5),
        (1, 256, 24, 80, 16, "HS", 1e-5),
        (1, 256, 48, 160, 16, "HS", 1e-5),
        (1, 512, 12, 40, 16, "BS", 1e-5),
        (1, 64, 96, 320, 16, "HS", 1e-5),
    ],
)
@pytest.mark.parametrize("device_params", [{"l1_small_size": 0}], indirect=True)
@run_for_blackhole("blackhole specific tests")
def test_group_norm_oft(device, N, C, H, W, num_groups, shard, eps):
    """Compare ``ttnn.group_norm`` with torch on the device's full compute grid.

    Args:
        device: Device fixture; the core grid is taken from
            ``device.compute_with_storage_grid_size()``.
        N, C, H, W: Shape of the torch input tensor.
        num_groups: Number of groups; must divide ``C``.
        shard: ``"HS"`` for height sharding (all cores treated as a single
            column of shards) or ``"BS"`` for block sharding.
        eps: Epsilon passed to both the torch and the ttnn op.

    Raises:
        ValueError: If ``shard`` is neither ``"HS"`` nor ``"BS"``.
    """
    assert C % num_groups == 0, "Number of channels must be divisible by number of groups"
    # Validate the sharding tag once, so that both places that branch on it
    # agree and an unknown tag fails with a clear message instead of an
    # unbound ``sharded_mem_config`` further down.
    if shard not in ("HS", "BS"):
        raise ValueError(f"Unsupported shard scheme {shard!r}; expected 'HS' or 'BS'")

    skip_if_not_blackhole_20_cores(device)
    compute_grid = device.compute_with_storage_grid_size()
    grid_size = ttnn.CoreGrid(y=compute_grid.y, x=compute_grid.x)
    # Generate torch tensor
    torch.manual_seed(0)
    torch_input_tensor = torch.rand((N, C, H, W), dtype=torch.bfloat16)
    torch_weight = torch.rand((C,), dtype=torch.bfloat16)
    torch_bias = torch.rand((C,), dtype=torch.bfloat16)
    # Execute torch group_norm
    torch_output_tensor = torch.nn.functional.group_norm(
        torch_input_tensor, num_groups, weight=torch_weight, bias=torch_bias, eps=eps
    )
    torch_output_tensor = torch_output_tensor.permute(0, 2, 3, 1).view(N, 1, W * H, C)

    input_tensor = torch_input_tensor.permute(0, 2, 3, 1).view(N, 1, W * H, C)
    input_tensor = ttnn.from_torch(
        input_tensor,
        dtype=ttnn.DataType.BFLOAT16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )
    # Effective shard grid: height sharding spreads rows over every core and
    # keeps the channel dimension whole; block sharding uses the grid as is.
    if shard == "HS":
        grid_x = grid_size.x * grid_size.y
        grid_y = 1
    else:  # "BS"
        grid_x = grid_size.x
        grid_y = grid_size.y
    # Generate input mask
    input_mask_tensor = ttnn.create_group_norm_input_mask(C, num_groups, grid_y)
    input_mask_tensor = ttnn.from_torch(
        input_mask_tensor,
        dtype=ttnn.DataType.BFLOAT16,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )
    # Generate gamma/beta tensors
    gamma = ttnn.create_group_norm_weight_bias_rm(torch_weight, C, grid_y)
    beta = ttnn.create_group_norm_weight_bias_rm(torch_bias, C, grid_y)

    gamma_t = ttnn.from_torch(
        gamma,
        dtype=ttnn.DataType.BFLOAT16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )
    beta_t = ttnn.from_torch(
        beta,
        dtype=ttnn.DataType.BFLOAT16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )

    # Generate shard config
    grid_coord = ttnn.CoreCoord(grid_size.x - 1, grid_size.y - 1)
    shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)})
    shard_shape = (H * W) // grid_x, C // grid_y
    if shard == "HS":
        shard_orientation = ttnn.ShardOrientation.ROW_MAJOR
        memory_layout = ttnn.types.TensorMemoryLayout.HEIGHT_SHARDED
    else:  # "BS"
        shard_orientation = ttnn.ShardOrientation.COL_MAJOR
        memory_layout = ttnn.types.TensorMemoryLayout.BLOCK_SHARDED
    shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, shard_orientation)
    sharded_mem_config = ttnn.MemoryConfig(memory_layout, ttnn.types.BufferType.L1, shard_spec)
    input_tensor = ttnn.to_memory_config(input_tensor, memory_config=sharded_mem_config)

    output_tensor = ttnn.group_norm(
        input_tensor,
        num_groups=num_groups,
        input_mask=input_mask_tensor,
        weight=gamma_t,
        bias=beta_t,
        memory_config=sharded_mem_config,
        core_grid=grid_size,
        epsilon=eps,
    )
    output_tensor = ttnn.to_torch(output_tensor)
    assert_with_pcc(torch_output_tensor, output_tensor, 0.999)
